#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@author:hengk
@contact: hengk@foxmail.com
@datetime:2019-11-02 11:25
"""
import yaml
from easydict import EasyDict
from torch.optim.lr_scheduler import _LRScheduler

def load_config(config_path):
    """Load a YAML config file and wrap it in an EasyDict for attribute-style access."""
    with open(config_path, mode='r', encoding='utf-8') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
        cfg = EasyDict(cfg)
    return cfg
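
# Usage sketch (the YAML keys below are illustrative, not a schema required by this repo):
# given a config.yaml containing
#     train:
#       lr: 0.1
#       warmup_epochs: 1
# then
#     cfg = load_config('config.yaml')
#     cfg.train.lr        # -> 0.1, EasyDict allows attribute-style access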


class WarmUpLR(_LRScheduler):
    """Linear warm-up learning rate scheduler.

    Args:
        optimizer: wrapped optimizer (e.g. SGD)
        total_iters: total number of iterations in the warm-up phase
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """For the first m batches (m = self.last_epoch), set the learning rate to
        base_lr * m / total_iters, i.e. ramp it up linearly from 0 to base_lr.
        """
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
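

# A minimal, self-contained sketch of how WarmUpLR is typically driven: step the
# scheduler once per batch for the first `total_iters` iterations so the learning
# rate climbs linearly to the base lr. The parameter, optimizer, and iteration
# count below are placeholders, not values prescribed by this repository.
if __name__ == "__main__":
    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=0.1)
    warmup_iters = 5
    warmup_scheduler = WarmUpLR(optimizer, total_iters=warmup_iters)

    for batch_idx in range(warmup_iters):
        optimizer.step()
        warmup_scheduler.step()  # lr ramps from ~0 up to the base lr of 0.1
        print(batch_idx, optimizer.param_groups[0]['lr'])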